{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 注册插件\n",
    "\n",
    "参考：[将 TVM 集成到您的项目中](https://xinetzone.github.io/tvm/docs/how_to/deploy/integrate.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{func}`tvm.register_extension` 用于将自定义类注册为 TVM（Tensor Virtual Machine）扩展类的函数。通过注册这个类可以作为 TVM 生成的函数的参数直接传递。以下是详细解读：\n",
    "\n",
    "- **目的**：将自定义类注册为 TVM 的扩展类，使其能够作为 TVM 生成的函数的参数。\n",
    "- **核心要求**：注册的类必须包含名为 `_tvm_handle` 的属性，用于返回表示句柄地址的整数值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te\n",
    "import numpy as np\n",
    "\n",
    "@tvm.register_extension\n",
    "class MyTensorView:\n",
    "    _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE\n",
    "\n",
    "    def __init__(self, arr):\n",
    "        self.arr = arr\n",
    "\n",
    "    @property\n",
    "    def _tvm_handle(self):\n",
    "        return self.arr._tvm_handle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可这样："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Sequence\n",
    "import tvm\n",
    "from tvm import te\n",
    "import numpy as np\n",
    "\n",
    "@tvm.register_extension\n",
    "@dataclass\n",
    "class MyTensorView:\n",
    "    arr: Sequence\n",
    "    _tvm_tcode: int = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE\n",
    "\n",
    "    @property\n",
    "    def _tvm_handle(self):\n",
    "        return self.arr._tvm_handle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DLTensor 兼容性\n",
    "\n",
    "DLTensor 兼容性 是指一个类或数据结构能够与 TVM 中的 DLTensor 类型无缝交互。DLTensor 是 TVM 中用于表示张量（Tensor）的核心数据结构，它与深度学习框架（如 PyTorch、TensorFlow）中的张量类似，但具有更高的灵活性和跨平台支持。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dtype = \"int64\"\n",
    "n = te.var(\"n\")\n",
    "Ab = tvm.tir.decl_buffer((n,), dtype)\n",
    "i = te.var(\"i\")\n",
    "ib = tvm.tir.ir_builder.create()\n",
    "A = ib.buffer_ptr(Ab)\n",
    "with ib.for_range(0, n - 1, \"i\") as i:\n",
    "    A[i + 1] = A[i] + 1\n",
    "stmt = ib.get()\n",
    "\n",
    "mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr(\"global_symbol\", \"arange\"))\n",
    "f = tvm.build(mod, target=\"stackvm\")\n",
    "a = tvm.nd.array(np.zeros(10, dtype=dtype))\n",
    "aview = MyTensorView(a)\n",
    "f(aview)\n",
    "np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(MyTensorView(arr=<tvm.nd.NDArray shape=(10,), cpu(0)>\n",
       " array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), _tvm_tcode=7),\n",
       " tvm.runtime.ndarray.NDArray)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "aview, type(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时 `MyTensorView` 接受 {class}`tvm.nd.NDArray` 作为输入，返回 {class}`tvm.nd.NDArray` 作为输出。但是："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(MyTensorView(arr=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), _tvm_tcode=7),\n",
       " __main__.MyTensorView)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = MyTensorView(np.zeros(10, dtype=dtype))\n",
    "b, type(b)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
